'''
Author: SlytherinGe
LastEditTime: 2021-11-03 16:37:11
'''
import mmcv
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
# Reference pixel for the similarity map, given as (x, y), i.e. (column, row).
TEST_POS = (100, 100)
# SSDD JPEG image directory (not referenced elsewhere in this file's visible code).
IMG_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/'
# Image identifier; used as the file stem of the annotation XML and the DFL cache file.
IMG_ID = '000395'
# Directory of VOC-style XML annotations.
# NOTE(review): this is the CAS-OpenSARShip dataset, while IMG_ROOT is SSDD —
# confirm that IMG_ID refers to the same image in both.
ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/CAS-OpenSARShip/ship_detection_online/VOC2012/Annotations/'

def load_DFL_vec_to_numpy(img_id, cache_dir='/home/gejunyao/fast_cache/DFL_Cache/'):
    """Load the cached DFL vectors of one image as a channels-last numpy array.

    The cache file ``<cache_dir>/<img_id>.pkl`` is read with ``mmcv.load``. Its
    content is expected to be a ``torch.Tensor`` laid out as ``(vec, h, w)``.

    Args:
        img_id (str): Image identifier, used as the cache file stem.
        cache_dir (str): Directory holding the ``.pkl`` cache files.

    Returns:
        numpy.ndarray: The DFL vectors permuted to ``(h, w, vec)``.
    """
    DFL_vec = mmcv.load(os.path.join(cache_dir, img_id + '.pkl'))
    # (vec, h, w) -> (h, w, vec). detach()/cpu() let .numpy() succeed for tensors
    # that require grad or live on a GPU; both are no-ops for plain CPU tensors.
    return DFL_vec.permute(1, 2, 0).detach().cpu().numpy()

def get_similarity_map(test_pos, img_id, cache_dir='/home/gejunyao/fast_cache/DFL_Cache/',
                       dfl_vec=None, normalize=False):
    """Return the L2 distance from every pixel's DFL vector to the one at ``test_pos``.

    Args:
        test_pos (tuple[int, int]): ``(x, y)`` position of the reference pixel,
            i.e. column first, then row.
        img_id (str): Image identifier. The DFL tensor is read from
            ``<cache_dir>/<img_id>.pkl`` unless ``dfl_vec`` is given.
        cache_dir (str): Directory holding the ``.pkl`` cache files.
        dfl_vec (torch.Tensor, optional): Already-loaded DFL tensor of shape
            ``(vec, h, w)``. Passing it skips the disk read, which is useful
            when querying many positions on the same image.
        normalize (bool): If True, L2-normalise every vector (eps 1e-9) before
            measuring distances. Defaults to False, the previous behaviour.

    Returns:
        torch.Tensor: Distance map of shape ``(h, w, 1)``. The entry at
        ``test_pos`` is 0.

    Raises:
        IndexError: If ``test_pos`` lies outside the ``w x h`` feature map.
    """
    if dfl_vec is None:
        dfl_vec = mmcv.load(os.path.join(cache_dir, img_id + '.pkl'))
    vec_len, h, w = dfl_vec.shape
    # (vec, h, w) -> (h, w, vec) -> (h*w, vec). Rows are flattened row-major,
    # so pixel (x, y) lands at row y*w + x (not y*h + x).
    flat = dfl_vec.permute(1, 2, 0).reshape(-1, vec_len)
    if normalize:
        flat = flat / (torch.norm(flat, dim=1, keepdim=True) + 1e-9)

    x, y = test_pos[0], test_pos[1]
    # Explicit check: a bad (x, y) could otherwise alias another pixel or wrap around.
    if not (0 <= x < w and 0 <= y < h):
        raise IndexError(
            f'test_pos {tuple(test_pos)} is outside the {w}x{h} feature map')
    test_vec = flat[y * w + x]

    dist = torch.norm(flat - test_vec, dim=1)
    return dist.reshape(h, w, 1)


if __name__ == '__main__':

    # Path of the VOC-style annotation XML for IMG_ID.
    # NOTE(review): ANNO_ROOT is the CAS-OpenSARShip dataset while IMG_ROOT is SSDD;
    # confirm IMG_ID identifies the same image in both.
    anno_path = os.path.join(ANNO_ROOT, IMG_ID+'.xml')

    # Cached DFL vectors for IMG_ID as an (h, w, vec) numpy array.
    DFL_vec = load_DFL_vec_to_numpy(IMG_ID)

    # TODO(review): `VL` is not imported or defined in this file's visible imports,
    # so this line raises NameError as written — import the module that provides
    # voc_label_preprocess.
    anno_info, anno_obj = VL.voc_label_preprocess(anno_path)
